import numpy as np
from pylab import *

def plot_loss_each_iter():
	"""Plot the smoothed per-record loss of the CNN and MLP 'd5' runs.

	Loads ``{cnn,mlp}/result/d5/{train,valid}_loss_iter.npy``, averages every
	``D`` consecutive records of each file into one point, draws the four
	curves on a single axes and saves the figure as ``loss_iter.pdf``.

	The length of every file must be a multiple of ``D``; otherwise
	``reshape`` raises ``ValueError``.
	"""
	D = 10  # number of consecutive records averaged into one plotted point
	# A validation point is drawn every 5 training points.
	# NOTE(review): presumably validation was logged once per 5 training
	# iterations -- confirm against the training script.
	valid_every = 5

	def smoothed(path):
		# reshape(-1, D) puts D *consecutive* records in each row, so the row
		# mean is a windowed average that matches the x positions below.
		# reshape(D, -1).mean(axis=0) must not be used here: it averages
		# records that are len/D apart and mixes early and late training
		# into every point.
		return np.load(path).reshape(-1, D).mean(axis=1)

	# (file, x distance between two points, line style, legend label),
	# listed in drawing order.
	curves = (
		('cnn/result/d5/train_loss_iter.npy', D, '-', 'CNN train loss'),
		('cnn/result/d5/valid_loss_iter.npy', valid_every * D, '--', 'CNN valid loss'),
		('mlp/result/d5/train_loss_iter.npy', D, '-', 'MLP train loss'),
		('mlp/result/d5/valid_loss_iter.npy', valid_every * D, '--', 'MLP valid loss'),
	)

	figure()
	p = subplot(111)
	for path, step, style, label in curves:
		loss = smoothed(path)
		# x positions are derived from the curve's own length, so a run of a
		# different length is still drawn against matching x values.
		p.plot(np.arange(loss.shape[0]) * step, loss, style, label=label)
	p.set_ylim((0, .6))
	p.set_xlabel(r'# of Iterations')
	p.set_ylabel(r'Loss')
	p.legend(loc='upper right')
	tight_layout()
	savefig("loss_iter.pdf")

def plot_comp(name, loss_lim, acc_lim):
	"""Draw CNN-vs-MLP loss and accuracy curves for the run called ``name``.

	Data is read from ``{cnn,mlp}/result/<name>/{train,val}_{loss,acc}.npy``.
	Two figures are written: ``loss_<name>.pdf`` with the y-axis spanning
	``(0, loss_lim)`` and ``acc_<name>.pdf`` spanning ``(acc_lim, 1)``.
	Both show the x range (1, 20).
	"""
	# (metric, y-axis window, y-axis label, legend corner), one entry per figure
	panels = (
		('loss', (0, loss_lim), r'Loss', 'upper right'),
		('acc', (acc_lim, 1), r'Accuracy', 'lower right'),
	)
	epochs = None
	for metric, y_window, y_label, corner in panels:
		# Files are read in the order: train cnn, train mlp, val cnn, val mlp.
		data = {}
		for split in ('train', 'val'):
			for model in ('cnn', 'mlp'):
				data[model, split] = np.load('%s/result/%s/%s_%s.npy' % (model, name, split, metric))
		if epochs is None:
			# The CNN training loss alone fixes the x values (1..length);
			# the accuracy figure reuses them.
			epochs = np.arange(data['cnn', 'train'].shape[0]) + 1

		figure()
		ax = subplot(111)
		# Drawing order: CNN train, CNN valid, MLP train, MLP valid.
		for model in ('cnn', 'mlp'):
			tag = model.upper()
			ax.plot(epochs, data[model, 'train'], '-', label='%s train %s' % (tag, metric))
			ax.plot(epochs, data[model, 'val'], '--', label='%s valid %s' % (tag, metric))
		ax.set_xlim((1, 20))
		ax.set_ylim(y_window)
		ax.set_xlabel(r'# of Epochs')
		ax.set_ylabel(y_label)
		ax.legend(loc=corner)
		tight_layout()
		savefig("%s_%s.pdf" % (metric, name))

def plot_comp2(name='d5', name2='nobn'):
	"""Compare run ``name`` against run ``name2`` separately for each model.

	For the CNN and for the MLP, the train/validation curves stored in
	``<model>/result/<name>/`` and ``<model>/result/<name2>/`` are overlaid.
	Four figures are written, in this order:

	* ``loss_comp_CNN_<name>_<name2>.pdf``
	* ``loss_comp_MLP_<name>_<name2>.pdf``
	* ``acc_comp_CNN_<name>_<name2>.pdf``
	* ``acc_comp_MLP_<name>_<name2>.pdf``

	The legend always calls the ``name`` curves "BN" and the ``name2`` curves
	"no BN", whatever run names are passed in.
	"""
	# (metric, model, y-axis window, y-axis label, legend corner) per figure
	panels = (
		('loss', 'cnn', (0, .1), r'Loss', 'upper right'),
		('loss', 'mlp', (0, .8), r'Loss', 'upper right'),
		('acc', 'cnn', (.97, 1), r'Accuracy', 'lower right'),
		('acc', 'mlp', (.75, 1), r'Accuracy', 'lower right'),
	)
	# (run directory, legend tag), drawn in this order
	runs = ((name, 'BN'), (name2, 'no BN'))
	# (file prefix, legend word, line style)
	splits = (('train', 'train', '-'), ('val', 'valid', '--'))

	for metric, model, y_window, y_label, corner in panels:
		figure()
		p = subplot(111)
		for run, tag in runs:
			for prefix, word, style in splits:
				values = np.load('%s/result/%s/%s_%s.npy' % (model, run, prefix, metric))
				# Each curve gets x = 1..len(values) from its own data, so
				# runs of different length can share one figure.
				epochs = np.arange(values.shape[0]) + 1
				p.plot(epochs, values, style, label='%s %s %s' % (tag, word, metric))
		p.set_xlim((1, 20))
		p.set_ylim(y_window)
		p.set_xlabel(r'# of Epochs')
		p.set_ylabel(y_label)
		p.legend(loc=corner)
		tight_layout()
		savefig("%s_comp_%s_%s_%s.pdf" % (metric, model.upper(), name, name2))

def plot_drop():
	"""Overlay the d1, d3, d5, d7 and d9 runs for every model and metric.

	For each model (``mlp``, ``cnn``) and metric (``acc``, ``loss``) one figure
	is created containing the train (solid) and validation (dashed) curve of
	every run, read from ``<model>/result/<run>/{train,val}_<metric>.npy``,
	and saved as ``drop_<model>_<metric>.pdf``.

	NOTE(review): the run names presumably encode dropout settings -- confirm
	against the training script.
	"""
	runs = ('d1', 'd3', 'd5', 'd7', 'd9')
	# y-axis window for each (model, metric) figure
	y_windows = {
		('mlp', 'acc'): (.4, 1),
		('cnn', 'acc'): (.93, 1),
		('mlp', 'loss'): (0, 1.35),
		('cnn', 'loss'): (0, .5),
	}
	# (y-axis label, legend corner) for each metric
	decorations = {
		'acc': (r'Accuracy', 'lower right'),
		'loss': (r'Loss', 'upper right'),
	}

	for model in ('mlp', 'cnn'):
		for metric in ('acc', 'loss'):
			figure()
			p = subplot(111)
			for name in runs:
				for face in ('train', 'val'):
					result = np.load('%s/result/%s/%s_%s.npy' % (model, name, face, metric))
					# x = 1..len(result) comes from the plotted array itself
					# rather than from an unrelated CNN file.
					epochs = np.arange(result.shape[0]) + 1
					p.plot(epochs, result, '-' if face == 'train' else '--', label='%s %s %s' % (name, face, metric))

			y_label, corner = decorations[metric]
			p.set_xlim((1, 20))
			p.set_ylim(y_windows[model, metric])
			p.set_ylabel(y_label)
			p.legend(loc=corner)
			p.set_xlabel(r'# of Epochs')
			tight_layout()
			savefig("drop_%s_%s.pdf" % (model, metric))
	

if __name__ == '__main__':
	# Script entry point: write the per-iteration loss figure, the CNN-vs-MLP
	# figures for run 'd5', and the name/name2 comparison figures with their
	# default run names.
	plot_loss_each_iter()
	plot_comp('d5', .6, .87)  # loss axis up to .6, accuracy axis from .87
	plot_comp2()
	# plot_drop() is intentionally left disabled here, as in the original script.
	# plot_drop()